{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "346198c8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0, 4300)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#Load the vocabulary: every line of the file is '<token> <integer id>'\n",
    "with open('./data/msr_paraphrase/zidian.txt') as fr:\n",
    "    pairs = (line.split(' ') for line in fr)\n",
    "    zidian = {word: int(index) for word, index in pairs}\n",
    "\n",
    "#Show the padding id and the vocabulary size\n",
    "zidian['<PAD>'], len(zidian)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1cf830ac",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2000"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "#定义数据\n",
    "class MsrDataset(Dataset):\n",
    "    def __init__(self):\n",
    "        self.data = pd.read_csv('./data/msr_paraphrase/数字化数据.txt', nrows=2000)\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        return self.data.iloc[i]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "\n",
    "#Number of rows the dataset loaded\n",
    "len(MsrDataset())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a3cc3e6b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[   1,    3,   12, 1794, 1795, 1759, 1796,  293,  574,   31,  730,   26,\n",
       "           144,    3, 1797, 1298, 1798,    2,    0,    0,    0,    0,    0,    0,\n",
       "             0,    0,    0,    0,    0,    0],\n",
       "         [   1,  242, 3381,  103,  318,  103,  352, 3636, 1862, 3108,  352,    3,\n",
       "            38, 2072,  650,   31, 3747, 3294,   22,  650,   17, 1080,  144, 1813,\n",
       "           115,    3,    2,    0,    0,    0],\n",
       "         [   1,  675,  105, 2111,  352, 2444, 2445,    3,   66,  117,  648,   31,\n",
       "          1632, 2446, 1110,   17,   56, 2261,  625,  390,    2,    0,    0,    0,\n",
       "             0,    0,    0,    0,    0,    0],\n",
       "         [   1,   33,  115,   56,  376, 2865,  112,   12, 1331,   31,   72,  461,\n",
       "           289,  446,   19,  203,    2,    0,    0,    0,    0,    0,    0,    0,\n",
       "             0,    0,    0,    0,    0,    0],\n",
       "         [   1,  778,   14,  706,    3,  112,  101,  850, 1203,   56,    3,  195,\n",
       "             3,    3, 3371,   84,  170,    3,   31,    3,   18,   38,   12, 3440,\n",
       "             2,    0,    0,    0,    0,    0]]), torch.Size([16, 30]))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def to_tensor(data):\n",
    "    b = len(data)\n",
    "    #N句话,每句话30个词\n",
    "    xs = np.zeros((b * 2,30))\n",
    "\n",
    "    for i in range(b):\n",
    "        same, s1, s2 = data[i]\n",
    "\n",
    "        #添加首尾符号,补0到统一长度\n",
    "        s1 = [zidian['<SOS>']] + s1.split(',')[:28] + [\n",
    "            zidian['<EOS>']\n",
    "        ] + [zidian['<PAD>']] * 28\n",
    "        xs[i] = s1[:30]\n",
    "\n",
    "        s2 = [zidian['<SOS>']] + s2.split(',')[:28] + [\n",
    "            zidian['<EOS>']\n",
    "        ] + [zidian['<PAD>']] * 28\n",
    "        xs[b + i] = s2[:30]\n",
    "\n",
    "    return torch.LongTensor(xs)\n",
    "\n",
    "\n",
    "#Data loader\n",
    "def get_dataloader():\n",
    "    '''Build a shuffled DataLoader over a fresh MsrDataset.\n",
    "\n",
    "    Batches hold 8 rows, an incomplete last batch is dropped, and\n",
    "    to_tensor collates every batch.\n",
    "    '''\n",
    "    loader_options = dict(batch_size=8,\n",
    "                          shuffle=True,\n",
    "                          drop_last=True,\n",
    "                          collate_fn=to_tensor)\n",
    "    return DataLoader(dataset=MsrDataset(), **loader_options)\n",
    "\n",
    "\n",
    "#Pull one batch from a fresh loader to use as sample input in the cells below\n",
    "sample = next(iter(get_dataloader()))\n",
    "\n",
    "sample[:5], sample.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8ff826c6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4, torch.Size([16, 29, 4300]))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class ForwardBackward(nn.Module):\n",
    "    def __init__(self, flip):\n",
    "        super().__init__()\n",
    "\n",
    "        self.rnn1 = nn.LSTM(input_size=256, hidden_size=256, batch_first=True)\n",
    "        self.rnn2 = nn.LSTM(input_size=256, hidden_size=256, batch_first=True)\n",
    "\n",
    "        self.fc = nn.Linear(in_features=256, out_features=4300)\n",
    "\n",
    "        self.flip = flip\n",
    "\n",
    "    def forward(self, x):\n",
    "        b = x.shape[0]\n",
    "\n",
    "        #初始化记忆\n",
    "        h = torch.zeros(1, b,256)\n",
    "        c = torch.zeros(1, b,256)\n",
    "\n",
    "        #顺序运算,维度不变\n",
    "        #[16,29,256] -> [16,29,256]\n",
    "\n",
    "        #如果是反向传播,把x逆序,由下面一个矩阵,变成下面第二个矩阵.\n",
    "        '''\n",
    "        [[1,2,3],\n",
    "         [4,5,6]]\n",
    "         \n",
    "        [[3,2,1],\n",
    "         [6,5,4]]'''\n",
    "        if self.flip:\n",
    "            x = torch.flip(x, dims=(1, ))\n",
    "\n",
    "        out1, (h, c) = self.rnn1(x, (h, c))\n",
    "        out2, (h, c) = self.rnn2(out1, (h, c))\n",
    "\n",
    "        #逆序后的x,计算出来的结果也是逆序的,把他们翻转回来\n",
    "        if self.flip:\n",
    "            x = torch.flip(x, dims=(1, ))\n",
    "            out1 = torch.flip(out1, dims=(1, ))\n",
    "            out2 = torch.flip(out2, dims=(1, ))\n",
    "\n",
    "        #全连接输出\n",
    "        #[16,29,256] -> [16,29,4300]\n",
    "        out3 = self.fc(out2)\n",
    "\n",
    "        return x, out1, out2, out3\n",
    "\n",
    "\n",
    "#Smoke test on random input. torch.FloatTensor(16,29,256) would return\n",
    "#uninitialized memory, which may contain nan/inf and differs between runs\n",
    "x = torch.randn(16, 29, 256)\n",
    "out = ForwardBackward(flip=True)(x)\n",
    "len(out), out[-1].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5f709554",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2, 4, torch.Size([16, 29, 4300]))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class ELMo(nn.Module):\n",
    "    '''Bidirectional language model: one embedding feeding a forward and a backward ForwardBackward stack.'''\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        #Token embedding used by both directions; index 0 is the padding index\n",
    "        self.embed = nn.Embedding(4300, 256, padding_idx=0)\n",
    "\n",
    "        self.fw = ForwardBackward(flip=False)\n",
    "        self.bw = ForwardBackward(flip=True)\n",
    "\n",
    "    def forward(self, x):\n",
    "        '''Embed the token ids and run both directions.\n",
    "\n",
    "        Returns (outs_f, outs_b), the output tuples of the forward and the\n",
    "        backward stack.\n",
    "        '''\n",
    "        #Embedding lookup\n",
    "        #[16,30] -> [16,30,256]\n",
    "        embedded = self.embed(x)\n",
    "\n",
    "        #Forward direction predicts the next token from the current one,\n",
    "        #so the last token is never used as input\n",
    "        without_last = embedded[:, :-1, :]\n",
    "\n",
    "        #Backward direction predicts the previous token from the current one,\n",
    "        #so the first token is never used as input\n",
    "        without_first = embedded[:, 1:, :]\n",
    "\n",
    "        return self.fw(without_last), self.bw(without_first)\n",
    "\n",
    "\n",
    "#Smoke test: run a freshly initialized ELMo on the sample batch and inspect the output structure\n",
    "out = ELMo()(sample)\n",
    "len(out), len(out[0]), out[0][-1].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b2697074",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0 8.374743461608887 0.0 0.0\n",
      "0 20 5.009000778198242 0.28663793103448276 0.30603448275862066\n",
      "0 40 4.625175476074219 0.3254310344827586 0.33405172413793105\n",
      "0 60 4.818905830383301 0.2823275862068966 0.28879310344827586\n",
      "0 80 4.759068012237549 0.3793103448275862 0.34913793103448276\n",
      "0 100 3.6306958198547363 0.4956896551724138 0.46120689655172414\n",
      "0 120 3.9946582317352295 0.44612068965517243 0.4224137931034483\n",
      "0 140 3.788682460784912 0.46120689655172414 0.43103448275862066\n",
      "0 160 3.794642448425293 0.4375 0.40301724137931033\n",
      "0 180 3.9886584281921387 0.4525862068965517 0.41594827586206895\n",
      "0 200 3.148937463760376 0.5366379310344828 0.5129310344827587\n",
      "0 220 4.02374792098999 0.4396551724137931 0.39870689655172414\n",
      "0 240 4.113767147064209 0.4245689655172414 0.3900862068965517\n"
     ]
    }
   ],
   "source": [
    "model = ELMo()\n",
    "opt = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
    "#Padding targets carry no information: keep them out of the loss so the\n",
    "#model is not rewarded for predicting <PAD>\n",
    "pad_id = zidian['<PAD>']\n",
    "loss_func = nn.CrossEntropyLoss(ignore_index=pad_id)\n",
    "\n",
    "for epoch in range(1):\n",
    "    for i, x in enumerate(get_dataloader()):\n",
    "        #x = [b,30]\n",
    "        opt.zero_grad()\n",
    "\n",
    "        #Run the model\n",
    "        outs_f, outs_b = model(x)\n",
    "\n",
    "        #Only the fully connected outputs are needed for the loss\n",
    "        #[b,29,vocab]\n",
    "        outs_f = outs_f[-1]\n",
    "        outs_b = outs_b[-1]\n",
    "\n",
    "        #The forward direction predicts the next token, so its targets skip the first token\n",
    "        #[b,30] -> [b,29]\n",
    "        x_f = x[:, 1:]\n",
    "        #The backward direction predicts the previous token, so its targets skip the last token\n",
    "        #[b,30] -> [b,29]\n",
    "        x_b = x[:, :-1]\n",
    "\n",
    "        #Flatten, as the loss expects [N,vocab] scores and [N] targets\n",
    "        #[b,29,vocab] -> [b*29,vocab]\n",
    "        vocab_size = outs_f.shape[-1]\n",
    "        outs_f = outs_f.reshape(-1, vocab_size)\n",
    "        outs_b = outs_b.reshape(-1, vocab_size)\n",
    "        #[b,29] -> [b*29]\n",
    "        x_f = x_f.reshape(-1)\n",
    "        x_b = x_b.reshape(-1)\n",
    "\n",
    "        #Compute the forward and backward losses separately and average them\n",
    "        loss_f = loss_func(outs_f, x_f)\n",
    "        loss_b = loss_func(outs_b, x_b)\n",
    "        loss = (loss_f + loss_b) / 2\n",
    "\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "\n",
    "        if i % 20 == 0:\n",
    "            #Accuracy over real (non-padding) targets only; counting <PAD>\n",
    "            #positions would inflate the numbers\n",
    "            keep_f = x_f != pad_id\n",
    "            keep_b = x_b != pad_id\n",
    "            correct_f = (x_f == outs_f.argmax(dim=1))[keep_f].sum().item()\n",
    "            correct_b = (x_b == outs_b.argmax(dim=1))[keep_b].sum().item()\n",
    "            total_f = max(keep_f.sum().item(), 1)\n",
    "            total_b = max(keep_b.sum().item(), 1)\n",
    "            print(epoch, i, loss.item(), correct_f / total_f, correct_b / total_b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "7ad4ba70",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([16, 28, 512])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_emb(x, layer=1):\n",
    "    '''Encode a batch of token ids into contextual embeddings with the notebook-level model.\n",
    "\n",
    "    x: [b, seq] token ids.\n",
    "    layer: index into each direction's output tuple. 0 is the input\n",
    "        embeddings, 1 the first LSTM output (default), 2 the second LSTM\n",
    "        output; 3 would select the vocabulary projection instead of an embedding.\n",
    "\n",
    "    Returns [b, seq - 2, 2 * features]: forward and backward features\n",
    "    concatenated for the positions both directions cover.\n",
    "    '''\n",
    "    #Run the model\n",
    "    outs_f, outs_b = model(x)\n",
    "\n",
    "    #Any layer's output can serve as the word embedding\n",
    "    #[16,29,256]\n",
    "    outs_f = outs_f[layer]\n",
    "    outs_b = outs_b[layer]\n",
    "\n",
    "    #The forward stack saw tokens 0..seq-2 and the backward stack tokens\n",
    "    #1..seq-1, so keep only the overlap (tokens 1..seq-2)\n",
    "    #[16,28,256]\n",
    "    outs_f = outs_f[:, 1:]\n",
    "    outs_b = outs_b[:, :-1]\n",
    "\n",
    "    #Concatenate both directions along the feature axis\n",
    "    #[16,28,256 + 256] = [16,28,512]\n",
    "    embed = torch.cat((outs_f, outs_b), dim=2)\n",
    "\n",
    "    return embed\n",
    "\n",
    "\n",
    "#Embedding shape for the sample batch\n",
    "get_emb(sample).shape"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
